# Module containing all functions needed to solve the infinite-horizon version of the QMT model.
# Employs homogeneity of degree 1 in incoming capital and solves over the outstanding-equity (s) state space, with Ke (s_{t+1} = S_{t+1}/K_{t+1}) as the choice for next period.
# "Kep" refers to the choice S_{t+1}/K_{t} (next-period equity scaled by current capital; dividing it by K_{t+1}/K_{t} gives s_{t+1}), whereas "Ke" refers to the state S_{t}/K_{t} ≡ s_{t}.

module QMT_ihzn_1_xirs

using LinearAlgebra, SparseArrays
using CompEcon
using LaTeXStrings
using Plots
using Distributions

# `IRFSolution` is the name of the IRF container struct; the lowercase `IRFsolution` does not match
# it and is kept only so that existing import lists naming it keep working.
export Options, StateSpace, Model, IndSolution, IRFsolution, IRFSolution, setup!, manprintln, solve_c, iter_howard_step, solve_valfunc_v, solve_valfunc_v0, valfunc_v, solve_L, solve_IRF, sim_firm_pols, sim_firm_pols_intp, menufun, xi_cdfun, xi_condexpn

# Extra commands for printing etc
# Print `s` to stdout and append it, followed by a newline, to the log file `filename`.
# The file is opened in append mode, so it is created if it does not exist yet, and the cost of
# each call no longer grows with the size of the log (the whole file used to be re-read and
# rewritten on every call).
# Returns the number of bytes appended.
function manprintln(s::String, filename::String)
    println(s)
    nbytes = open(filename, "a") do io
        write(io, s, "\n")
    end
    return nbytes
end

# Computation options
mutable struct Options
    # Integer settings use the concrete `Int` type: abstract `Integer` fields are boxed and make
    # every access type-unstable. Values passed as any Integer are still converted on construction.

    # Tolerances, iterations
    Nbell::Int          # Maximum number of Bellman (contraction) iterations
    Nhow::Int           # Number of Howard (policy-evaluation) steps per run
    first_how::Int      # Bellman iteration from which Howard steps are run (every iteration from then on)
    tolc::Float64       # Tolerance on the norm of the change in value-function coefficients
    itermaxL::Int       # Maximum iterations to find stationary dist L
    tolL::Float64       # Tolerance on L

    # Print / plot ("Y" enables)
    prnt::String        # Print out c-solution convergence (the coefs on value function)
    sfigs::String       # Save all figures plotted
    plotc::String       # Plot policies when solving for c
    plotSD::String      # Plot stationary distribution

    # Related to Sobol calibration
    flag_cconv::Int     # Indicates whether the SS solution for c converged (1) or not (0)
    n_job::Int          # The number of the job in a given "calibX_" run
end

# Statespace parameters
mutable struct StateSpace
    # Integer settings are stored as concrete Vector{Int}: a Vector{Integer} has an abstract element
    # type, so every element is boxed and every access is type-unstable. Integer vectors passed
    # to the constructor are converted automatically.
    n::Vector{Int}          # Number of nodes in each (m,Ke) dimension (for interpolating v)
    nf::Vector{Int}         # Number of points for (K,m,Ke) in histogram L
    np::Vector{Int}         # Number of points on choice grids for (i, y, Kep)
    curv::Array{Float64,1}  # Curvature for (m,Ke) grids (1 is no curvature)
    curvf::Array{Float64,1} # Curvature for fine (K,m,Ke) grids for histogram (1 is no curvature)
    curvp::Array{Float64,1} # Curvature for choice grids for (i, y, Kep) (1 is no curvature)
    spliorder::Vector{Int}  # Order of spline for (m,Ke): 1 or 3

    # Grid limits
    mmin::Float64           # Lower bound on m state in v domain
    mmax::Float64           # Upper bound on m state in v domain
    Kemin::Float64          # Lower bound on Ke state in v domain
    Kemax::Float64          # Upper bound on Ke state in v domain
    imin::Float64           # Lower bound on i choice
    imax::Float64           # Upper bound on i choice
    ymin::Float64           # Lower bound on y (y/K) choice
    ymax::Float64           # Upper bound on y (y/K) choice
    Kepmin::Float64         # Lower bound on Kep (Ke'/K) choice
    Kepmax::Float64         # Upper bound on Kep (Ke'/K) choice
    Kminf::Float64          # Lower bound on K in fine GRID for histogram
    Kmaxf::Float64          # Upper bound on K in fine GRID for histogram
    mminf::Float64          # Lower bound on m/K in fine GRID for histogram
    mmaxf::Float64          # Upper bound on m/K in fine GRID for histogram
    Keminf::Float64         # Lower bound on Ke/K in fine GRID for histogram
    Kemaxf::Float64         # Upper bound on Ke/K in fine GRID for histogram

    # Derived objects, filled in by setup!
    igrid::Array{Float64,1}                 # Grid for choices of I/K
    ygrid::Array{Float64,1}                 # Grid for choices of y
    Kepgrid::Array{Float64,1}               # Grid for choices of Ke'/K
    Kgridf::Array{Float64,1}                # Fine grid for K
    mgridf::Array{Float64,1}                # Fine grid for m/K
    Kegridf::Array{Float64,1}               # Fine grid for Ke/K
    sf::Array{Float64,2}                    # Fine composite grid (K,m,Ke)
    Nsf::Int64                              # Length of fine composite grid (K,m,Ke)
    sf0::Array{Float64,2}                   # Fine composite grid (m,Ke)
    Nsf0::Int64                             # Length of fine composite grid (m,Ke)
    Q::SparseMatrixCSC{Float64,Int64}       # Transition matrix for firm states
    Phi_v::SparseMatrixCSC{Float64,Int64}   # Basis matrix of Ev interpolant
    mgrid0::Array{Float64,1}                # Initial m grid, kept for computing basis matrices because a cubic spline extends the grid
    mgrid::Array{Float64,1}                 # m grid as used by the function space
    Kegrid0::Array{Float64,1}               # Initial Ke grid, kept for computing basis matrices because a cubic spline extends the grid
    Kegrid::Array{Float64,1}                # Ke grid as used by the function space
    Nm::Int64                               # Length of m-grid
    NKe::Int64                              # Length of Ke-grid
    fspace::Dict{Symbol,Any}                # Container for function space for approximating V
    s::Array{Float64,2}                     # Composite grid (m,Ke) combination
    Ns::Int64                               # Length of composite grid
    # Composite grids for choices; the basis matrix for the v interpolant at the implied mp may be re-evaluated per solution problem
    sc::Array{Float64,2}                    # Composite grid for choices (i,y,Kep) combination
    Nsc::Int64                              # Length of composite grid for choices
    sc0::Array{Float64,2}                   # Composite grid for choices (i,y) combination, when do not issue
    Nsc0::Int64                             # Length of composite grid for choices, when do not issue
    Phi_vc::SparseMatrixCSC{Float64,Int64}  # Basis matrix of Ev interpolant, evaluated at mp's implied by sc (original note: not really used anywhere)
end

# Model parameters
mutable struct Model
    # Scalar model parameters plus a name. `Omega` is recomputed by setup! from eps_e, rK, beta,
    # delta, a, psi and pi, so a value passed in at construction is overwritten there.
    beta::Float64           # Discount factor
    eps_e::Float64          # Entrepreneur's utility from dividends
    rK::Float64             # Dividend productivity of capital
    a::Float64              # CM productivity of capital
    nu::Float64             # Multiplicative coefficient on investment costs
    eta::Float64            # Exponent on investment costs
    delta::Float64          # Depreciation rate
    delta_0::Float64        # The point around which convex adjustment costs are centered
    pi::Float64             # Survival rate of entrepreneurs
    psi::Float64            # Steady state equity price
    rm::Float64             # Net return on savings m (enters the budget as 1+rm)
    m_0::Float64            # Entrants' m/K ratio
    prob_msh::Float64       # Probability of the shock to m'/K' (weight on the shocked continuation value)
    x_msh::Float64          # Size of the shock to m'/K' (added to m'/K' when the shock hits)
    Omega::Float64          # Steady state multiplier on utility cost of issued equity
    xi_ub::Float64          # Upper bound of the issuance-cost (xi) distribution; original note says "U", presumably uniform -- confirm in xi_cdfun

    name::String            # Model name, used in output file names (e.g. "mod_savedout/<name>_output.txt")
end

# Container for individual firm solution, on arbitrary grid s (full, 3 dimensional)
mutable struct IndSolution
    # One entry per state-grid point for each of the following
    val::Vector{Float64}    # Value function at the grid point
    i::Vector{Float64}      # Investment rate I/K
    mp::Vector{Float64}     # Next-period money ratio m'/K'
    Kep::Vector{Float64}    # Equity choice Ke'/K
    ie::Vector{Float64}     # Equity issuance Ie/K
    y::Vector{Float64}      # Entrepreneur's current CM consumption y
    cix::Vector{Int64}      # Row of the chosen combination in sc (issuers) OR sc0 (non-issuers) -- check which grid applies before indexing
end

# Container for the full firm solution (currently holds only the coefficients of the Ev interpolant)
mutable struct FirmSolution
    # Coefficients of the Ev interpolant over the (m, Ke) grid, stacked into one vector of length Nm*NKe
    c_vec::Array{Float64,1}       # Coefficients for interpolant of Ev in (Nm X NKe) vector
end

# Container for firm's solution of transition path after MIT shock
mutable struct IRFSolution
    # Equity-price path
    psi_col::Vector{Float64}        # Path of psi over the transition
    psi_ss::Float64                 # Steady-state psi, understood to apply from period T+1 on

    # Issuers' objects (suffix 1); every matrix is (Ns x T)
    c_col1::Matrix{Float64}         # Coefficients of the Ev interpolant
    i_col1::Matrix{Float64}         # Choices of i
    mp_col1::Matrix{Float64}        # Choices of mp
    Kep_col1::Matrix{Float64}       # Choices of Kep
    ie_col1::Matrix{Float64}        # Choices of ie
    y_col1::Matrix{Float64}         # Choices of y
    cix_col1::Matrix{Float64}       # Row of the chosen combination in the composite grid sc (stored as Float64)

    # Non-issuers' objects (suffix 0), conditional on no issuance; every matrix is (Ns x T)
    i_col0::Matrix{Float64}         # Choices of i
    mp_col0::Matrix{Float64}        # Choices of mp
    Kep_col0::Matrix{Float64}       # Choices of Kep
    ie_col0::Matrix{Float64}        # Choices of ie
    y_col0::Matrix{Float64}         # Choices of y
    cix_col0::Matrix{Float64}       # Row of the chosen combination in the composite grid sc0 (stored as Float64)

    xi_col::Matrix{Float64}         # Issuance cost cutoffs (Ns x T)
end

# Container for firm simulation results -- ONE firm
mutable struct FirmSimResults
    # Simulated path of a single firm; every vector has one entry per period (length T)
    i_col::Vector{Float64}          # Investment rate i
    mp_col::Vector{Float64}         # m'/K'
    Kep_col::Vector{Float64}        # Kep
    ie_col::Vector{Float64}         # Equity issuance ie
    y_col::Vector{Float64}          # CM consumption y
    Ei_col::Vector{Float64}         # Extensive-margin equity issuance
    Ei_prob_col::Vector{Float64}    # Probability of extensive-margin equity issuance
end


# Set up the types
# Keyword constructor for Options: every setting has a default and is forwarded to the
# positional constructor in field order.
function Options(; Nbell::Integer=6, Nhow::Integer=0, first_how::Integer=2, tolc::Float64=1e-8,
                 itermaxL::Integer=5000, tolL::Float64=1e-10,
                 prnt::String="Y", sfigs::String="N", plotc::String="N", plotSD::String="N",
                 flag_cconv::Integer=0, n_job::Integer=999)
    return Options(Nbell, Nhow, first_how, tolc, itermaxL, tolL,
                   prnt, sfigs, plotc, plotSD,
                   flag_cconv, n_job)
end

# Keyword constructor for StateSpace. Only the user-facing settings (node counts, curvatures,
# spline orders and grid limits) normally need to be given; the derived objects default to empty
# placeholders that setup! fills in. Arguments are forwarded in field order.
function StateSpace(; n::Vector{Int64}=[4,4], nf::Vector{Int64}=[10,10,10], np::Vector{Int64}=[10,10,10],
                    curv::Vector{Float64}=[1.0,1.0], curvf::Vector{Float64}=[1.0,1.0,1.0], curvp::Vector{Float64}=[1.0,1.0,1.0],
                    spliorder::Vector{Int64}=[1,1],
                    mmin::Float64=0.0, mmax::Float64=1.0, Kemin::Float64=0.0, Kemax::Float64=1.0,
                    imin::Float64=0.0, imax::Float64=1.0, ymin::Float64=0.0, ymax::Float64=1.0,
                    Kepmin::Float64=0.0, Kepmax::Float64=1.0,
                    Kminf::Float64=0.1, Kmaxf::Float64=5.0, mminf::Float64=0.0, mmaxf::Float64=5.0,
                    Keminf::Float64=0.0, Kemaxf::Float64=1.0,
                    igrid=Vector{Float64}(undef, 0), ygrid=Vector{Float64}(undef, 0), Kepgrid=Vector{Float64}(undef, 0),
                    Kgridf=Vector{Float64}(undef, 0), mgridf=Vector{Float64}(undef, 0), Kegridf=Vector{Float64}(undef, 0),
                    sf=Matrix{Float64}(undef, 0, 0), Nsf=0, sf0=Matrix{Float64}(undef, 0, 0), Nsf0=0,
                    Q=spzeros(0,0), Phi_v=spzeros(0,0),
                    mgrid0=Vector{Float64}(undef, 0), mgrid=Vector{Float64}(undef, 0),
                    Kegrid0=Vector{Float64}(undef, 0), Kegrid=Vector{Float64}(undef, 0),
                    Nm=0, NKe=0, fspace=Dict{Symbol,Any}(), s=Matrix{Float64}(undef, 0, 0), Ns=0,
                    sc=Matrix{Float64}(undef, 0, 0), Nsc=0, sc0=Matrix{Float64}(undef, 0, 0), Nsc0=0,
                    Phi_vc=spzeros(0,0))
    return StateSpace(n, nf, np, curv, curvf, curvp, spliorder,
                      mmin, mmax, Kemin, Kemax, imin, imax, ymin, ymax, Kepmin, Kepmax,
                      Kminf, Kmaxf, mminf, mmaxf, Keminf, Kemaxf,
                      igrid, ygrid, Kepgrid, Kgridf, mgridf, Kegridf,
                      sf, Nsf, sf0, Nsf0, Q, Phi_v,
                      mgrid0, mgrid, Kegrid0, Kegrid, Nm, NKe, fspace, s, Ns,
                      sc, Nsc, sc0, Nsc0, Phi_vc)
end

# Keyword constructor for Model: every parameter has a default and is forwarded to the
# positional constructor in field order. (The `pi` keyword shadows Base.pi inside this method.)
function Model(; beta::Float64=0.99, eps_e::Float64=1.0, rK::Float64=0.1, a::Float64=0.0,
               nu::Float64=0.1, eta::Float64=2.0, delta::Float64=0.025, delta_0::Float64=0.0,
               pi::Float64=1.0, psi::Float64=1.0, rm::Float64=1.0/0.99-1, m_0::Float64=0.3,
               prob_msh::Float64=0.0, x_msh::Float64=0.0, Omega::Float64=0.0, xi_ub::Float64=0.0,
               name::String="mod")
    return Model(beta, eps_e, rK, a, nu, eta, delta, delta_0, pi, psi, rm, m_0,
                 prob_msh, x_msh, Omega, xi_ub, name)
end

# Model setup
# Grid of `n` points on [lo, hi] that is evenly spaced in x^curv and mapped back by x^(1/curv)
# (curv = 1.0 gives a linear grid).
_curved_grid(lo, hi, n, curv) = range(lo^curv, hi^curv, length=n).^(1.0/curv)

# Model setup: fill `sspace` with the state/choice/histogram grids, the spline function space and
# its basis matrix, and set `modl.Omega` implied by the other parameters. `opts` is currently unused.
# Returns the (mutated) pair (modl, sspace).
function setup!(modl, sspace, opts)
    curv, curvf, curvp = sspace.curv, sspace.curvf, sspace.curvp

    # Coarse (m, Ke) grids on which v is approximated. The *0 versions are stored separately
    # because the cubic-spline function space may extend the node grids below.
    mgrid0   = _curved_grid(sspace.mmin,  sspace.mmax,  sspace.n[1], curv[1])
    Kegrid0  = _curved_grid(sspace.Kemin, sspace.Kemax, sspace.n[2], curv[2])

    # Function space and nodes (fundef adds knot points for cubic splines)
    fspace   = fundef([:spli, mgrid0, 0, sspace.spliorder[1]], [:spli, Kegrid0, 0, sspace.spliorder[2]])
    s, sgrid = funnode(fspace)
    Ns       = size(s,1)

    # Node grids as actually used by the function space
    mgrid    = sgrid[1]
    Kegrid   = sgrid[2]
    Nm       = size(mgrid,1)
    NKe      = size(Kegrid,1)

    # Basis matrix at the nodes. With linear interpolation in both dimensions it is imposed to be
    # the exact sparse identity, so funbase is only needed otherwise.
    sspace.Phi_v = sspace.spliorder == [1,1] ? sparse(I, Ns, Ns) : funbase(fspace, s)

    # Choice grids for (i, y, Kep)
    igrid    = _curved_grid(sspace.imin,   sspace.imax,   sspace.np[1], curvp[1])
    ygrid    = _curved_grid(sspace.ymin,   sspace.ymax,   sspace.np[2], curvp[2])
    Kepgrid  = _curved_grid(sspace.Kepmin, sspace.Kepmax, sspace.np[3], curvp[3])

    # Fine (K, m, Ke) grids for the histogram
    Kgridf   = _curved_grid(sspace.Kminf,  sspace.Kmaxf,  sspace.nf[1], curvf[1])
    mgridf   = _curved_grid(sspace.mminf,  sspace.mmaxf,  sspace.nf[2], curvf[2])
    Kegridf  = _curved_grid(sspace.Keminf, sspace.Kemaxf, sspace.nf[3], curvf[3])

    # Composite grids
    sf       = gridmake(Kgridf, mgridf, Kegridf)   # fine (K, m, Ke)
    sf0      = gridmake(mgridf, Kegridf)           # fine (m, Ke)
    sc       = gridmake(igrid, ygrid, Kepgrid)     # choices (i, y, Kep), when issuing
    sc0      = gridmake(igrid, ygrid)              # choices (i, y), when not issuing

    sspace.Kgridf   = Kgridf
    sspace.mgridf   = mgridf
    sspace.Kegridf  = Kegridf
    sspace.igrid    = igrid
    sspace.ygrid    = ygrid
    sspace.Kepgrid  = Kepgrid
    sspace.sf       = sf
    sspace.Nsf      = size(sf,1)
    sspace.sf0      = sf0
    sspace.Nsf0     = size(sf0,1)
    sspace.sc       = sc
    sspace.Nsc      = size(sc,1)
    sspace.sc0      = sc0
    sspace.Nsc0     = size(sc0,1)

    # Q for the stationary-distribution approximation: initialized here to a sparse vector of ones
    # over the fine (K, m, Ke) grid (presumably a placeholder -- TODO confirm where it is rebuilt)
    sspace.Q        = sparse(ones(sspace.nf[1]*sspace.nf[2]*sspace.nf[3]))

    # Store the coarse-grid objects in the state space
    sspace.mgrid0   = mgrid0
    sspace.mgrid    = mgrid
    sspace.Kegrid0  = Kegrid0
    sspace.Kegrid   = Kegrid
    sspace.Nm       = Nm
    sspace.NKe      = NKe
    sspace.fspace   = fspace
    sspace.s        = s
    sspace.Ns       = Ns

    # Omega implied by the parameters
    R_psi           = 1.0-modl.delta+modl.a/modl.psi
    modl.Omega      = -modl.eps_e * modl.rK * modl.beta*R_psi / (1.0-modl.beta*modl.pi*R_psi)

    return modl, sspace
end

# STEADY STATE
# Solve for v
# --- Plotting helpers for solve_c (only used when opts.plotc == "Y") ---

# Rows of the composite (m, Ke) state grid holding all m-nodes at Ke-node `iKe`
# (assumes m varies fastest in sspace.s).
_inds_at_Ke(sspace, iKe) = ((iKe-1)*sspace.Nm+1):(iKe*sspace.Nm)

# Rows holding all NKe Ke-nodes at m-node `im`: every Nm-th row, starting at `im`.
_inds_at_m(sspace, im) = im:sspace.Nm:((sspace.NKe-1)*sspace.Nm+im)

# Legend labels for a slice at Ke-node `iKe` / at m-node `im`
_lbl_at_Ke(sspace, iKe) = latexstring("\$(a^{b},$(round(sspace.Kegrid[iKe],digits=2)))\$")
_lbl_at_m(sspace, im)   = latexstring("\$($(round(sspace.mgrid[im],digits=2)),s)\$")

# The six plotted policy series (v, i, mp, ie, y, Kep) of a solution container at rows `inds`, as columns
_pol_cols(v, inds) = [v.val[inds] v.i[inds] v.mp[inds] v.ie[inds] v.y[inds] v.Kep[inds]]

# y-axis labels of the six policy subplots
_pol_ylabels() = [latexstring("\$v\$") latexstring("\$x\$") latexstring("\$a^{b'}\$") latexstring("\$e^{b}\$") latexstring("\$y^{b}\$") latexstring("\$s'\$")]

# Plot the policies of `v` against m, one line per Ke-node in `iKe_plot`, and save to `fname`
function _plot_pols_vs_m(v, sspace, iKe_plot, fname)
    iKe1 = iKe_plot[1]
    Plots.plot(sspace.mgrid, _pol_cols(v, _inds_at_Ke(sspace, iKe1)), labels=[_lbl_at_Ke(sspace, iKe1) :none :none :none :none :none], xlabel=L"a^{b}", ylabel=_pol_ylabels(), layout=(3,2), legend=[:topright :none :none :none :none :none])
    for iKK in iKe_plot[2:end]
        Plots.plot!(sspace.mgrid, _pol_cols(v, _inds_at_Ke(sspace, iKK)), labels=[_lbl_at_Ke(sspace, iKK) :none :none :none :none :none], xlabel=L"a^{b}", layout=(3,2))
    end
    savefig(fname)
end

# Plot the policies of `v` against Ke, one line per m-node in `im_plot`, and save to `fname`
function _plot_pols_vs_Ke(v, sspace, im_plot, fname)
    im1 = im_plot[1]
    Plots.plot(sspace.Kegrid, _pol_cols(v, _inds_at_m(sspace, im1)), labels=[_lbl_at_m(sspace, im1) :none :none :none :none :none], xlabel=L"s", ylabel=_pol_ylabels(), layout=(3,2), legend=[:topright :none :none :none :none :none])
    for imm in im_plot[2:end]
        Plots.plot!(sspace.Kegrid, _pol_cols(v, _inds_at_m(sspace, imm)), labels=[_lbl_at_m(sspace, imm) :none :none :none :none :none], xlabel=L"s", ylabel=_pol_ylabels(), layout=(3,2), legend=[:topright :none :none :none :none :none])
    end
    savefig(fname)
end

# Plot the grid vector `x` against m, one line per Ke-node; extra keywords go to the first plot call only
function _plot_series_vs_m(x, sspace, iKe_plot, ylab, fname; kw...)
    iKe1 = iKe_plot[1]
    Plots.plot(sspace.mgrid, x[_inds_at_Ke(sspace, iKe1)]; label=_lbl_at_Ke(sspace, iKe1), xlabel=L"a^{b}", ylabel=ylab, kw...)
    for iKK in iKe_plot[2:end]
        Plots.plot!(sspace.mgrid, x[_inds_at_Ke(sspace, iKK)]; label=_lbl_at_Ke(sspace, iKK), xlabel=L"a^{b}")
    end
    savefig(fname)
end

# Plot the grid vector `x` against Ke, one line per m-node; extra keywords go to the first plot call only
function _plot_series_vs_Ke(x, sspace, im_plot, ylab, fname; kw...)
    im1 = im_plot[1]
    Plots.plot(sspace.Kegrid, x[_inds_at_m(sspace, im1)]; label=_lbl_at_m(sspace, im1), xlabel=L"s", ylabel=ylab, kw...)
    for imm in im_plot[2:end]
        Plots.plot!(sspace.Kegrid, x[_inds_at_m(sspace, imm)]; label=_lbl_at_m(sspace, imm), xlabel=L"s")
    end
    savefig(fname)
end

# STEADY STATE: solve for the coefficients c of the Ev interpolant by Bellman iteration, with
# Howard steps from iteration opts.first_how on.
#   iKe_plot / im_plot : Ke-node and m-node indices of the slices plotted when opts.plotc == "Y"
# Side effects: sets opts.flag_cconv (1 on convergence, 0 otherwise), appends progress to
# "mod_savedout/<modl.name>_output.txt" and, when plotting, saves figures to "mod_savedout/conv_figs/".
# Returns the coefficient vector.
function solve_c(modl, sspace, opts, iKe_plot, im_plot)
    s        = sspace.s
    ns       = size(s,1)
    logfile  = "mod_savedout/"*modl.name*"_output.txt"

    # Initial guess for the coefficients
    c_old    = 4.0*ones(sspace.Ns)

    flag_cconv, opts.flag_cconv = false, 0
    t_eqstart = time_ns()
    if opts.prnt == "Y"; manprintln("Bellman iterations for v:", logfile); end
    for citer = 1:opts.Nbell
        # 1. Values with and without issuance, given the current interpolant
        v1,      = solve_valfunc_v(c_old, modl.psi, modl.rm, s, modl, sspace, opts, false)
        v0,      = solve_valfunc_v0(c_old, modl.psi, modl.rm, s, modl, sspace, opts, false)

        # Issuance-cost cutoffs and implied issuance probabilities
        xi_vec   = v1.val - v0.val
        Ei_prob  = xi_cdfun(xi_vec, modl, sspace, opts)

        # Expected value over the issuance-cost draw
        Ev       = Ei_prob .* (v1.val - xi_condexpn(xi_vec, modl, sspace, opts)) + (ones(ns)-Ei_prob) .* (v0.val)

        # 2. Refit the coefficients; 3. measure the change
        c_new    = sspace.Phi_v \ Ev
        dc       = norm(c_new-c_old)

        if opts.prnt == "Y"
            manprintln("$citer\t dc = $dc, \t Time: $((time_ns()-t_eqstart)/1.0e9)", logfile)
        end
        if opts.plotc == "Y"
            _plot_pols_vs_m(v1, sspace, iKe_plot, "mod_savedout/conv_figs/conv_v1_$(citer).png")
            _plot_pols_vs_Ke(v1, sspace, im_plot, "mod_savedout/conv_figs/conv_v1Ke_$(citer).png")
            _plot_pols_vs_m(v0, sspace, iKe_plot, "mod_savedout/conv_figs/conv_v0_$(citer).png")
            _plot_pols_vs_Ke(v0, sspace, im_plot, "mod_savedout/conv_figs/conv_v0Ke_$(citer).png")
            _plot_series_vs_m(Ei_prob, sspace, iKe_plot, L"P_{I}", "mod_savedout/conv_figs/conv_EiP_$(citer).png")
            _plot_series_vs_Ke(Ei_prob, sspace, im_plot, L"P_{I}", "mod_savedout/conv_figs/conv_EiPKe_$(citer).png")
            _plot_series_vs_m(Ev, sspace, iKe_plot, L"EJ", "mod_savedout/conv_figs/conv_Ev_$(citer).png"; legend=:bottomright)
            _plot_series_vs_Ke(Ev, sspace, im_plot, L"EJ", "mod_savedout/conv_figs/conv_EvKe_$(citer).png"; legend=:topright)
        end

        if dc < opts.tolc; flag_cconv = true; opts.flag_cconv = 1; end
        if flag_cconv
            manprintln("Bellman iterations -> v convergence in steps: $citer, time: $((time_ns()-t_eqstart)/1.0e9).", logfile)
            break
        end
        if citer == opts.Nbell
            manprintln("Finished $citer Bellman iterations without v convergence, distance: $dc, time: $((time_ns()-t_eqstart)/1.0e9).", logfile)
        end

        # Update interpolant
        c_old    = c_new

        # Howard steps, holding this iteration's policies fixed
        if citer == opts.first_how
            manprintln("Running $(opts.Nhow) Howard steps below ...", logfile)
        end
        if citer >= opts.first_how
            c_old   = iter_howard_step(c_old, modl.psi, modl.rm, v1.cix, v0.cix, modl, sspace, opts)
        end
    end

    # Simply return the converged coefficients / values on grid
    return vec(c_old)
end

# Run Howard steps
# Run opts.Nhow Howard (policy-evaluation) steps. Keeps fixed the issuers' choices `cix_c1` (rows of
# sspace.sc) and the non-issuers' choices `cix_c0` (rows of sspace.sc0). Repeatedly re-evaluates
# the expected value on the state grid sspace.s and refits the interpolant coefficients.
# Returns the updated coefficient vector (c_in itself when opts.Nhow == 0).
function iter_howard_step(c_in, psi_cur, rm_cur, cix_c1, cix_c0, modl, sspace, opts)
    s    = sspace.s
    ns   = size(s,1)
    sc   = sspace.sc
    sc0  = sspace.sc0
    Rk   = 1.0-modl.delta+modl.a/psi_cur   # factor multiplying the current state s[:,2] below

    # The policies are fixed, so the implied next-period states are computed once, outside the loop.
    # Issuers: m'/K' implied by the choices (i, y, Kep) in sc
    mp_temp1     = (psi_cur .* (sc[cix_c1,3] .- Rk*s[:,2] ) + (1.0+rm_cur) * s[:,1] + modl.a*ones(ns) - sc[cix_c1,2] - menufun("adj_cost", s[:,1], sc[cix_c1,1], sc[cix_c1,2], sc[cix_c1,3], modl, sspace, opts)) ./ ((1.0-modl.delta)*ones(ns) + sc[cix_c1,1])
    # Same, conditional on receiving the m' shock
    mp_temp_msh1 = mp_temp1 + modl.x_msh*ones(ns)
    # Ke'/K' implied by the choices
    KepdK_temp1  = sc[cix_c1,3]./(1.0-modl.delta .+ sc[cix_c1,1])

    # Non-issuers: Kep is pinned down by the current state; cix_c0 indexes the no-issuance grid sc0
    Kep_temp0    = Rk*s[:,2]
    KepdK_temp0  = Kep_temp0./(1.0-modl.delta .+ sc0[cix_c0,1])
    # The psi_cur .* zeros term is the (zero) proceeds from new issuance
    mp_temp0     = (psi_cur .* zeros(size(cix_c0,1)) + (1.0+rm_cur) * s[:,1] + modl.a*ones(ns) - sc0[cix_c0,2] - menufun("adj_cost", s[:,1], sc0[cix_c0,1], sc0[cix_c0,2], Kep_temp0, modl, sspace, opts)) ./ ((1.0-modl.delta)*ones(ns) + sc0[cix_c0,1])
    mp_temp_msh0 = mp_temp0 + modl.x_msh*ones(ns)

    c_old = c_in
    for _ = 1:opts.Nhow
        # Continuation values under the current coefficients, averaged over the m' shock
        Vcont1  = (1.0-modl.prob_msh)*funeval(c_old, sspace.fspace, [mp_temp1 KepdK_temp1])[1] + modl.prob_msh*funeval(c_old, sspace.fspace, [mp_temp_msh1 KepdK_temp1])[1]
        Vcont0  = (1.0-modl.prob_msh)*funeval(c_old, sspace.fspace, [mp_temp0 KepdK_temp0])[1] + modl.prob_msh*funeval(c_old, sspace.fspace, [mp_temp_msh0 KepdK_temp0])[1]

        # Evaluate v at the fixed policies. The state is passed as the 1x2 row s[ii,:]', matching
        # how solve_valfunc_v builds its state (plain s[ii] would pick out only the m entry).
        v_val1  = Array{Float64}(undef, ns)
        v_val0  = Array{Float64}(undef, ns)
        for ii = 1:ns
            v_val1[ii] = valfunc_v(psi_cur, s[ii,:]', sc[cix_c1[ii],1], sc[cix_c1[ii],2], sc[cix_c1[ii],3], mp_temp1[ii], Vcont1[ii], modl, sspace, opts)
            v_val0[ii] = valfunc_v(psi_cur, s[ii,:]', sc0[cix_c0[ii],1], sc0[cix_c0[ii],2], Kep_temp0[ii], mp_temp0[ii], Vcont0[ii], modl, sspace, opts)
        end

        # Issuance-cost cutoffs, issuance probabilities and the expected value over the cost draw
        xi_vec  = v_val1 - v_val0
        Ei_prob = xi_cdfun(xi_vec, modl, sspace, opts)
        Ev_val  = Ei_prob .* (v_val1 - xi_condexpn(xi_vec, modl, sspace, opts)) + (ones(ns)-Ei_prob) .* (v_val0)

        # Refit the interpolant
        c_old   = sspace.Phi_v \ Ev_val
    end

    return c_old
end

# Solve the v value function by grid search, given continuation coefficients "c" and current psi.
# For each incoming state s[ii,:] (column 1: m, column 2: Ke), every row of the composite choice
# grid sspace.sc (columns: i, y, Kep) is evaluated with valfunc_v and the maximizer is kept.
#
# Arguments
#   c        : coefficients of the continuation value function on sspace.fspace
#   psi_cur  : current psi
#   rm_cur   : current rm, entering the law of motion for m' (callers in this file pass modl.rm)
#   s        : states to solve at, one state per row
#   calc_jac : whether to also build a jacobian (see NOTE below)
# Returns (IndSolution(v, i, mp, Kep, ie, y, cix), jac); cix is the maximizing row of sspace.sc.
function solve_valfunc_v(c, psi_cur, rm_cur, s, modl, sspace, opts, calc_jac)
    ns  = size(s,1)
    sc  = sspace.sc
    Nsc = sspace.Nsc

    # Containers for the value and the maximizing policies at every state
    v, i, mp, Kep, ie, y, cix = Array{Float64}(undef, ns), Array{Float64}(undef, ns), Array{Float64}(undef, ns), Array{Float64}(undef, ns), Array{Float64}(undef, ns), Array{Float64}(undef, ns), Array{Int64}(undef, ns)

    # The Kep/K' choices implied by the composite choice grid (CONSTANT across incoming states)
    KepdK_sc_temp = sc[:,3]./(1.0-modl.delta .+ sc[:,1])

    # Factor multiplying the incoming Ke, both in the m' law of motion and in ie = Kep - Ke_fac*Ke.
    # It does not depend on the state, so compute it once.
    Ke_fac = 1.0-modl.delta+modl.a/psi_cur

    for ii = 1:ns
        # Keep the state as a 1x2 row: valfunc_v branches on size(s_cur,1)
        s_temp = s[ii,:]'

        # m' implied by every composite choice. Uses the rm_cur argument, consistent with the
        # policy-evaluation (Howard) step, which also uses rm_cur.
        mp_sc_temp     = (psi_cur .* (sc[:,3] .- Ke_fac*s_temp[2]) + (1.0+rm_cur) * s_temp[1]*ones(Nsc) + modl.a*ones(Nsc) - sc[:,2] - menufun("adj_cost", s_temp[1]*ones(Nsc), sc[:,1], sc[:,2], sc[:,3], modl, sspace, opts)) ./ ((1.0-modl.delta)*ones(Nsc) + sc[:,1])
        # Also, conditional on receiving the m' shock
        mp_sc_temp_msh = mp_sc_temp + modl.x_msh*ones(Nsc)

        # Continuation value for every composite choice, averaged over the m' shock
        Vcont_temp_nomsh = funeval(c, sspace.fspace, [mp_sc_temp KepdK_sc_temp])[1]
        Vcont_temp_msh   = funeval(c, sspace.fspace, [mp_sc_temp_msh KepdK_sc_temp])[1]
        Vcont_temp_vec   = (1.0-modl.prob_msh)*Vcont_temp_nomsh + modl.prob_msh*Vcont_temp_msh

        # Start strictly below valfunc_v's infeasibility penalty (-1e3) so that some choice is
        # always recorded, even when every choice is infeasible
        max_val = -1e3-1.0
        # Simple grid search over the composite sc space
        for cix_temp in 1:Nsc
            i_temp, y_temp, Kep_temp = sc[cix_temp,:]
            mp_temp = mp_sc_temp[cix_temp]
            cur_val = valfunc_v(psi_cur, s_temp, i_temp, y_temp, Kep_temp, mp_temp, Vcont_temp_vec[cix_temp], modl, sspace, opts)
            if cur_val > max_val
                max_val = cur_val
                v[ii]   = cur_val
                i[ii]   = i_temp
                y[ii]   = y_temp
                Kep[ii] = Kep_temp
                ie[ii]  = Kep_temp - Ke_fac*s_temp[2]
                mp[ii]  = mp_temp
                cix[ii] = cix_temp
            end
        end
    end

    # NOTE(review): Phi_kpz is not defined inside this function; unless a module-level global of
    # that name exists elsewhere, calc_jac=true raises UndefVarError. Callers in this file pass false.
    if calc_jac
        jac = sspace.Phi_vu - modl.beta*(1.0-modl.eta) * sspace.Emat_vu * Phi_kpz
    else
        jac = Array{Float64,2}(undef,0,0)
    end

    # Pack up output
    v = IndSolution(v, i, mp, Kep, ie, y, cix)

    return v, jac
end

# Same problem as solve_valfunc_v, but impose no issuance (ie = 0).
# The choice grid is sspace.sc0 (columns: i, y). Kep is pinned down by the incoming state as
# (1-delta+a/psi_cur)*Ke.
#
# Arguments are as in solve_valfunc_v; rm_cur enters the law of motion for m' (callers in this
# file pass modl.rm).
# Returns (IndSolution(v, i, mp, Kep, ie, y, cix), jac); cix is the maximizing row of sspace.sc0.
function solve_valfunc_v0(c, psi_cur, rm_cur, s, modl, sspace, opts, calc_jac)
    ns   = size(s,1)
    sc0  = sspace.sc0
    Nsc0 = sspace.Nsc0

    # Containers for the value and the maximizing policies at every state
    v, i, mp, Kep, ie, y, cix = Array{Float64}(undef, ns), Array{Float64}(undef, ns), Array{Float64}(undef, ns), Array{Float64}(undef, ns), Array{Float64}(undef, ns), Array{Float64}(undef, ns), Array{Int64}(undef, ns)

    # Factor multiplying the incoming Ke (state-invariant)
    Ke_fac = 1.0-modl.delta+modl.a/psi_cur

    for ii = 1:ns
        # Keep the state as a 1x2 row: valfunc_v branches on size(s_cur,1)
        s_temp = s[ii,:]'

        # Without issuance, Kep is fixed by the incoming Ke, identically for every composite choice
        Kep_temp = Ke_fac*s_temp[2]
        # The Kep/K' choices on the composite grid (NOT CONSTANT across incoming states, because
        # Kep depends on Ke when not issuing)
        KepdK_sc_temp = Kep_temp./(1.0-modl.delta .+ sc0[:,1])

        # m' implied by every composite choice; the psi*ie issuance term is identically zero here.
        # Uses the rm_cur argument, consistent with the policy-evaluation (Howard) step.
        mp_sc_temp     = (psi_cur .* zeros(Nsc0) + (1.0+rm_cur) * s_temp[1]*ones(Nsc0) + modl.a*ones(Nsc0) - sc0[:,2] - menufun("adj_cost", s_temp[1]*ones(Nsc0), sc0[:,1], sc0[:,2], Kep_temp*ones(Nsc0), modl, sspace, opts)) ./ ((1.0-modl.delta)*ones(Nsc0) + sc0[:,1])
        # Also, conditional on receiving the m' shock
        mp_sc_temp_msh = mp_sc_temp + modl.x_msh*ones(Nsc0)

        # Continuation value for every composite choice, averaged over the m' shock
        Vcont_temp_nomsh = funeval(c, sspace.fspace, [mp_sc_temp KepdK_sc_temp])[1]
        Vcont_temp_msh   = funeval(c, sspace.fspace, [mp_sc_temp_msh KepdK_sc_temp])[1]
        Vcont_temp_vec   = (1.0-modl.prob_msh)*Vcont_temp_nomsh + modl.prob_msh*Vcont_temp_msh

        # Start strictly below valfunc_v's infeasibility penalty (-1e3) so that some choice is
        # always recorded
        max_val = -1e3-1.0
        # Simple grid search over the composite sc0 space
        for cix_temp in 1:Nsc0
            i_temp, y_temp = sc0[cix_temp,:]
            mp_temp = mp_sc_temp[cix_temp]
            cur_val = valfunc_v(psi_cur, s_temp, i_temp, y_temp, Kep_temp, mp_temp, Vcont_temp_vec[cix_temp], modl, sspace, opts)
            if cur_val > max_val
                max_val = cur_val
                v[ii]   = cur_val
                i[ii]   = i_temp
                y[ii]   = y_temp
                Kep[ii] = Kep_temp
                ie[ii]  = Kep_temp - Ke_fac*s_temp[2]
                mp[ii]  = mp_temp
                cix[ii] = cix_temp
            end
        end
    end

    # NOTE(review): Phi_kpz is not defined inside this function; unless a module-level global of
    # that name exists elsewhere, calc_jac=true raises UndefVarError. Callers in this file pass false.
    if calc_jac
        jac = sspace.Phi_vu - modl.beta*(1.0-modl.eta) * sspace.Emat_vu * Phi_kpz
    else
        jac = Array{Float64,2}(undef,0,0)
    end

    # Pack up output
    v = IndSolution(v, i, mp, Kep, ie, y, cix)

    return v, jac
end

# Objective value of one candidate choice (i_temp, y_temp, Kep_temp), given the continuation value
# Vcont_temp_val already interpolated at the implied m' (mp_temp).
# A choice is infeasible, and scored with the penalty -1e3, when mp_temp < 0,
# Kep_temp > 1-delta+i_temp, or y_temp < 0.
# psi_cur, sspace and opts are unused; s_cur only selects `*` vs `.*` for the continuation term.
function valfunc_v(psi_cur, s_cur, i_temp, y_temp, Kep_temp, mp_temp, Vcont_temp_val, modl, sspace, opts)
    Kp = 1.0-modl.delta+i_temp   # K'/K

    violated = any((mp_temp < 0.0, Kep_temp > Kp, y_temp < 0.0))
    violated && return -1e3

    flow_term = modl.beta*modl.eps_e*modl.rK*(Kp-Kep_temp)
    disc      = modl.beta * modl.pi * Kp
    cont_term = size(s_cur,1) == 1 ? disc * Vcont_temp_val : disc .* Vcont_temp_val

    return y_temp + flow_term + cont_term
end
# Variant of valfunc_v that also echoes the candidate choice back to the caller.
# Returns (value, i_temp, mp_temp, Kep_temp, y_temp); value is -1e3 for an infeasible choice
# (mp_temp < 0, Kep_temp > 1-delta+i_temp, or y_temp < 0). calc_all is only used for dispatch.
function valfunc_v(psi_cur, s_cur, i_temp, y_temp, Kep_temp, mp_temp, Vcont_temp_val, modl, sspace, opts, calc_all)
    Kp = 1.0-modl.delta+i_temp   # K'/K

    val = if any((mp_temp < 0.0, Kep_temp > Kp, y_temp < 0.0))
        -1e3
    else
        flow_term = modl.beta*modl.eps_e*modl.rK*(Kp-Kep_temp)
        disc      = modl.beta * modl.pi * Kp
        y_temp + flow_term + (size(s_cur,1) == 1 ? disc * Vcont_temp_val : disc .* Vcont_temp_val)
    end

    return val, i_temp, mp_temp, Kep_temp, y_temp
end


# Solve the firm's problem along a path for "psi".
# Read T from the length of the entered "psi_col".
# Timing convention: the "psi_col" fed in has length "T". It is understood that the value function in "T+1" is the steady state one.
# Solve the problem and the value function backward, from period T down to 1, starting from the
# steady-state coefficients c_sol_ss as the T+1 continuation value.
# In each period, both the issuing (solve_valfunc_v) and non-issuing (solve_valfunc_v0) problems are
# solved on the grid sspace.s. Their policies, the issuance cutoffs xi = v1 - v0, and the
# coefficients of the expected value function (integrated over the issuance-cost draw) are stored.
# iKe_plot: Ke-grid indices whose policies are plotted when opts.plotc == "Y"; im_plot is not used here.
# NOTE(review): the module exports `IRFsolution` (lowercase "s") but this constructs `IRFSolution`;
# confirm which name the struct actually has.
function solve_IRF(psi_col, c_sol_ss, modl, sspace, opts, iKe_plot, im_plot)
    # A. Globals
    s       = sspace.s
    ns      = size(s,1)

    # Initialize
    T           = length(psi_col)
    # Save variables to return
    psi_ss      = modl.psi

    # Collections that will contain solution objects along path
    c_col           = Array{Float64}(undef, ns, T) # Simply expected value function values on the whole grid
    i_col1, i_col0      = Array{Float64}(undef, ns, T), Array{Float64}(undef, ns, T) # i choices on the whole grid, conditional on issuance
    mp_col1, mp_col0    = Array{Float64}(undef, ns, T), Array{Float64}(undef, ns, T) # mp choices on the whole grid, conditional on issuance
    Kep_col1, Kep_col0    = Array{Float64}(undef, ns, T), Array{Float64}(undef, ns, T) # Kep choices on the whole grid, conditional on issuance
    ie_col1, ie_col0    = Array{Float64}(undef, ns, T), Array{Float64}(undef, ns, T) # ie choices on the whole grid, conditional on issuance
    y_col1, y_col0      = Array{Float64}(undef, ns, T), Array{Float64}(undef, ns, T) # y choices on the whole grid, conditional on issuance
    cix_col1, cix_col0  = Array{Int64}(undef, ns, T), Array{Int64}(undef, ns, T)   # cix choices on the whole grid, conditional on issuance
    xi_col              = Array{Float64}(undef, ns, T) # The values of issuance cutoff costs

    # Initialize t+1 value function coefficients
    c_next          = c_sol_ss

    t_start = time_ns()
    # Loop backward, solving the firm's problem
    manprintln("Computing firm IRF solution for T=$(T).", "mod_savedout/"*modl.name*"_output.txt")
    for tt = T:-1:1
        manprintln("Period $(tt).", "mod_savedout/"*modl.name*"_output.txt")
        # Solve the firm problem for t, calculating all policies (issuers "1", non-issuers "0")
        v_cur1,   = solve_valfunc_v(c_next, psi_col[tt], modl.rm, s, modl, sspace, opts, false)
        v_cur0,   = solve_valfunc_v0(c_next, psi_col[tt], modl.rm, s, modl, sspace, opts, false)
        i_col1[:,tt], mp_col1[:,tt], Kep_col1[:,tt], ie_col1[:,tt], y_col1[:,tt], cix_col1[:,tt] = v_cur1.i,  v_cur1.mp,  v_cur1.Kep,  v_cur1.ie,  v_cur1.y, v_cur1.cix
        i_col0[:,tt], mp_col0[:,tt], Kep_col0[:,tt], ie_col0[:,tt], y_col0[:,tt], cix_col0[:,tt] = v_cur0.i,  v_cur0.mp,  v_cur0.Kep,  v_cur0.ie,  v_cur0.y, v_cur0.cix

        # Compute cost cutoffs: a firm issues iff its cost draw is below v1 - v0
        xi_vec_cur   = v_cur1.val - v_cur0.val
        xi_col[:,tt] = xi_vec_cur

        # Probabilities of issuing
        Ei_prob_cur  = xi_cdfun(xi_vec_cur, modl, sspace, opts)

        # Compute the actual expected value, over investment oppy shock
        # (issuers pay the expected cost conditional on it being below the cutoff)
        Ev_cur       = Ei_prob_cur .* (v_cur1.val - xi_condexpn(xi_vec_cur, modl, sspace, opts)) + (ones(ns)-Ei_prob_cur) .* (v_cur0.val)

        # Based on current v values, save value function coefs and set as continuation value function
        c_next      = sspace.Phi_v \ Ev_cur
        c_col[:,tt] = c_next

        # Plot the value function and policies, only as functions of m for now.
        # plt_inds_temp selects the states with the given Ke index; this assumes the states in
        # sspace.s are ordered with m varying fastest within each Ke block (TODO confirm).
        if opts.plotc == "Y"
            # Issuers' policies, as function of m
            plt_inds_temp = ((iKe_plot[1]-1)*sspace.Nm+1):(iKe_plot[1]*sspace.Nm)
            Plots.plot(sspace.mgrid, [v_cur1.val[plt_inds_temp] v_cur1.i[plt_inds_temp] v_cur1.mp[plt_inds_temp] v_cur1.ie[plt_inds_temp] v_cur1.y[plt_inds_temp] v_cur1.Kep[plt_inds_temp]], labels=[latexstring("\$(a^{b},$(round(sspace.Kegrid[iKe_plot[1]],digits=2)))\$") :none :none :none :none :none], xlabel=L"a^{b}", ylabel=[latexstring("\$v\$") latexstring("\$x\$") latexstring("\$a^{b'}\$") latexstring("\$e^{b}\$") latexstring("\$y^{b}\$") latexstring("\$s'\$")], layout=(3,2), legend=[:topright :none :none :none :none :none])
            for iKK in iKe_plot[2:end]
                plt_inds_temp = ((iKK-1)*sspace.Nm+1):(iKK*sspace.Nm)
                Plots.plot!(sspace.mgrid, [v_cur1.val[plt_inds_temp] v_cur1.i[plt_inds_temp] v_cur1.mp[plt_inds_temp] v_cur1.ie[plt_inds_temp] v_cur1.y[plt_inds_temp] v_cur1.Kep[plt_inds_temp]], labels=[latexstring("\$(a^{b},$(round(sspace.Kegrid[iKK],digits=2)))\$") :none :none :none :none :none], xlabel=L"a^{b}", layout=(3,2))
            end
            savefig("mod_savedout/conv_figs/irf_v1_T$(T)_$(tt).png")
        end
        if opts.plotc == "Y"
            # Non-issuers' policies, as function of m
            plt_inds_temp = ((iKe_plot[1]-1)*sspace.Nm+1):(iKe_plot[1]*sspace.Nm)
            Plots.plot(sspace.mgrid, [v_cur0.val[plt_inds_temp] v_cur0.i[plt_inds_temp] v_cur0.mp[plt_inds_temp] v_cur0.ie[plt_inds_temp] v_cur0.y[plt_inds_temp] v_cur0.Kep[plt_inds_temp]], labels=[latexstring("\$(a^{b},$(round(sspace.Kegrid[iKe_plot[1]],digits=2)))\$") :none :none :none :none :none], xlabel=L"a^{b}", ylabel=[latexstring("\$v\$") latexstring("\$x\$") latexstring("\$a^{b'}\$") latexstring("\$e^{b}\$") latexstring("\$y^{b}\$") latexstring("\$s'\$")], layout=(3,2), legend=[:topright :none :none :none :none :none])
            for iKK in iKe_plot[2:end]
                plt_inds_temp = ((iKK-1)*sspace.Nm+1):(iKK*sspace.Nm)
                Plots.plot!(sspace.mgrid, [v_cur0.val[plt_inds_temp] v_cur0.i[plt_inds_temp] v_cur0.mp[plt_inds_temp] v_cur0.ie[plt_inds_temp] v_cur0.y[plt_inds_temp] v_cur0.Kep[plt_inds_temp]], labels=[latexstring("\$(a^{b},$(round(sspace.Kegrid[iKK],digits=2)))\$") :none :none :none :none :none], xlabel=L"a^{b}", layout=(3,2))
            end
            savefig("mod_savedout/conv_figs/irf_v0_T$(T)_$(tt).png")
        end
        if opts.plotc == "Y"
            # Issuing probability, as function of m
            plt_inds_temp = ((iKe_plot[1]-1)*sspace.Nm+1):(iKe_plot[1]*sspace.Nm)
            Plots.plot(sspace.mgrid, Ei_prob_cur[plt_inds_temp], label=latexstring("\$(a^{b},$(round(sspace.Kegrid[iKe_plot[1]],digits=2)))\$"), xlabel=L"a^{b}", ylabel=L"P_{I}")
            for iKK in iKe_plot[2:end]
                plt_inds_temp = ((iKK-1)*sspace.Nm+1):(iKK*sspace.Nm)
                Plots.plot!(sspace.mgrid, Ei_prob_cur[plt_inds_temp], label=latexstring("\$(a^{b},$(round(sspace.Kegrid[iKK],digits=2)))\$"), xlabel=L"a^{b}")
            end
            savefig("mod_savedout/conv_figs/irf_EiP_T$(T)_$(tt).png")
        end
    end
    manprintln("Computation of firm IRF solution complete in time: $((time_ns()-t_start)/1.0e9).", "mod_savedout/"*modl.name*"_output.txt")

    # Save everything
    irfsoln = IRFSolution(psi_col, psi_ss, c_col, i_col1, mp_col1, Kep_col1, ie_col1, y_col1, cix_col1, i_col0, mp_col0, Kep_col0, ie_col0, y_col0, cix_col0, xi_col)

    return irfsoln
end

# Simulate a single firm path given collections of shape (ns X T) of policies defined on the state
# grid sspace.s, fed in as (i_pol_col, mp_pol_col, Kep_pol_col, ie_pol_col, y_pol_col).
# Each period the firm sits at the grid point of sspace.s closest (Euclidean distance in (m, Ke))
# to its state. It starts from s_init (a 1x2 row or a length-2 vector), and thereafter its state is
# (mp', Kep/(1-delta+i)) implied by the policies.
# The simulation runs for min(Tm, size(i_pol_col,2)) periods.
# With a single policy set there is no issuance draw, so Ei_path and Ei_prob_path are filled with NaN.
# Returns FirmSimResults(i, mp, Kep, ie, y, Ei, Ei_prob), the same layout as the two-policy method.
function sim_firm_pols(s_init, Tm, i_pol_col, mp_pol_col, Kep_pol_col, ie_pol_col, y_pol_col, modl, sspace, opts)
    s = sspace.s

    # Maximum length of simulation
    T = min(Tm, size(i_pol_col,2))

    # Row index of the grid point in s closest to the point pt = (m, Ke)
    nearest_state(pt) = argmin(vec(sum(abs2, s .- reshape(collect(pt), 1, :), dims=2)))

    # Initialize paths
    i_path, mp_path, Kep_path, ie_path, y_path = ntuple(_ -> Array{Float64}(undef, T), 5)
    Ei_path      = fill(NaN, T)
    Ei_prob_path = fill(NaN, T)

    i_state = nearest_state(s_init)
    for tt = 1:T
        i_path[tt]   = i_pol_col[i_state, tt]
        mp_path[tt]  = mp_pol_col[i_state, tt]
        Kep_path[tt] = Kep_pol_col[i_state, tt]
        ie_path[tt]  = ie_pol_col[i_state, tt]
        y_path[tt]   = y_pol_col[i_state, tt]
        # Update next period's state: (m', Kep/K')
        KepdK_next = Kep_path[tt]/(1.0-modl.delta+i_path[tt])
        i_state    = nearest_state((mp_path[tt], KepdK_next))
    end

    return FirmSimResults(i_path, mp_path, Kep_path, ie_path, y_path, Ei_path, Ei_prob_path)
end

# Simulate a single firm path with two sets of (ns X T) policies: set "1" (issuing) and set "0"
# (not issuing). At grid state j in period t the firm uses set "1" iff xi_draws[t] < xi_col[j,t].
# The state is snapped each period to the nearest row of sspace.s (squared distance): first s_init
# (a 1x2 row), then (mp', Kep/(1-delta+i)).
# The simulation runs for min(Tm, size(i_pol_col1,2)) periods.
# Returns FirmSimResults(i, mp, Kep, ie, y, Ei, Ei_prob), where Ei is the realized issuance indicator
# and Ei_prob is xi_cdfun evaluated at the current cutoff.
function sim_firm_pols(s_init, Tm, xi_draws, i_pol_col1, mp_pol_col1, Kep_pol_col1, ie_pol_col1, y_pol_col1, i_pol_col0, mp_pol_col0, Kep_pol_col0, ie_pol_col0, y_pol_col0, xi_col, modl, sspace, opts)
    grid  = sspace.s
    ngrid = size(grid,1)

    T = Tm > size(i_pol_col1,2) ? size(i_pol_col1,2) : Tm

    # Issuance indicator for every (grid state, period): the period's draw is below the state's cutoff
    issue_mat = repeat(xi_draws', sspace.Ns, 1) .< xi_col

    # Select the issuing or non-issuing policy elementwise, written as a 0/1-weighted combination
    pick(p1, p0) = issue_mat .* p1 + (ones(size(p1)) - issue_mat) .* p0
    i_pol_col   = pick(i_pol_col1,   i_pol_col0)
    mp_pol_col  = pick(mp_pol_col1,  mp_pol_col0)
    Kep_pol_col = pick(Kep_pol_col1, Kep_pol_col0)
    ie_pol_col  = pick(ie_pol_col1,  ie_pol_col0)
    y_pol_col   = pick(y_pol_col1,   y_pol_col0)

    i_path, mp_path, Kep_path, ie_path, y_path, Ei_path, Ei_prob_path = ntuple(_ -> Array{Float64}(undef, T), 7)

    # Grid row nearest (in squared distance) to a 1x2 point
    closest_row(pt) = argmin(sum((repeat(pt, ngrid) .- grid).^2, dims=2))[1]

    state = closest_row(s_init)
    for t in 1:T
        Kep_path[t] = Kep_pol_col[state, t]
        i_path[t]   = i_pol_col[state, t]
        mp_path[t]  = mp_pol_col[state, t]
        ie_path[t]  = ie_pol_col[state, t]
        y_path[t]   = y_pol_col[state, t]
        Ei_path[t]      = issue_mat[state, t]
        Ei_prob_path[t] = xi_cdfun(xi_col[state, t], modl, sspace, opts)
        # Next period's state: (m', Kep/K')
        ratio_next = Kep_path[t]/(1.0-modl.delta+i_path[t])
        state = closest_row([mp_path[t] ratio_next])
    end

    return FirmSimResults(i_path, mp_path, Kep_path, ie_path, y_path, Ei_path, Ei_prob_path)
end

# Simulate a single firm path by interpolating two sets of policies (issuing "1" / not issuing "0")
# at the firm's exact (off-grid) state, with issuance-cost draws xi_draws.
# The arguments must be collections of INTERPOLANT COEFFICIENTS on sspace.fspace, one column per
# period. (In the linear interpolation case these equal the policies; in the cubic case they differ.)
# Period t uses set "1" iff xi_draws[t] is below the interpolated cutoff xi_c_col[:,t] at the state.
# The state starts at s_init and then moves to [mp' Kep/(1-delta+i)].
# Returns FirmSimResults(i, mp, Kep, ie, y, Ei, Ei_prob) over min(Tm, size(i_cpol_col1,2)) periods.
function sim_firm_pols_intp(s_init, Tm, xi_draws, i_cpol_col1, mp_cpol_col1, Kep_cpol_col1, ie_cpol_col1, y_cpol_col1, i_cpol_col0, mp_cpol_col0, Kep_cpol_col0, ie_cpol_col0, y_cpol_col0, xi_c_col, modl, sspace, opts)
    T = Tm > size(i_cpol_col1,2) ? size(i_cpol_col1,2) : Tm

    # Evaluate the interpolant with coefficients `coefs` at state `x`, returning a scalar
    interp_at(coefs, x) = funeval(coefs, sspace.fspace, x)[1][1]

    i_path, mp_path, Kep_path, ie_path, y_path, Ei_path, Ei_prob_path = ntuple(_ -> Array{Float64}(undef, T), 7)

    state = s_init
    for t in 1:T
        cutoff = interp_at(xi_c_col[:,t], state)
        issues = xi_draws[t] < cutoff[1]

        if issues
            Kep_c, i_c, mp_c, ie_c, y_c = Kep_cpol_col1, i_cpol_col1, mp_cpol_col1, ie_cpol_col1, y_cpol_col1
        else
            Kep_c, i_c, mp_c, ie_c, y_c = Kep_cpol_col0, i_cpol_col0, mp_cpol_col0, ie_cpol_col0, y_cpol_col0
        end

        Kep_path[t] = interp_at(Kep_c[:,t], state)
        i_path[t]   = interp_at(i_c[:,t], state)
        mp_path[t]  = interp_at(mp_c[:,t], state)
        ie_path[t]  = interp_at(ie_c[:,t], state)
        y_path[t]   = interp_at(y_c[:,t], state)
        Ei_path[t]  = issues ? 1 : 0

        Ei_prob_path[t] = xi_cdfun(cutoff, modl, sspace, opts)

        # Next period's state: (m', Kep/K')
        state = [mp_path[t] Kep_path[t]/(1.0-modl.delta+i_path[t])]
    end

    return FirmSimResults(i_path, mp_path, Kep_path, ie_path, y_path, Ei_path, Ei_prob_path)
end


###
# Stationary distribution and steady state equilibrium
###
# Solve for the stationary distribution over the fine (m, Ke) histogram grid, given the
# steady-state value-function coefficients "c_sol_ss".
# Issuing (1) and non-issuing (0) policies are solved on sspace.sf0. Firms issue with probability
# xi_cdfun(v1 - v0), and with probability modl.prob_msh m' is shifted by modl.x_msh.
# Only a share modl.pi of firms survive each period. Entrants of mass (1-pi) start at
# (m_0, Ke = 0).
# Returns the distribution vector Lnew after convergence (or after opts.itermaxL iterations).
function solve_L(c_sol_ss, modl, sspace, opts)
    manprintln("Solving for stationary distribution.", "mod_savedout/"*modl.name*"_output.txt")
    sf0     = sspace.sf0
    Nd      = size(sf0,1)

    # Solve firm's problem given solution in "c_sol_ss" on fine grid
    vf1,     = solve_valfunc_v(c_sol_ss, modl.psi, modl.rm, sf0, modl, sspace, opts, false)
    vf0,     = solve_valfunc_v0(c_sol_ss, modl.psi, modl.rm, sf0, modl, sspace, opts, false)

    # Next-period states implied by each policy set: m' and Kep/K'
    mp_pol1     = vf1.mp
    mp_pol0     = vf0.mp
    Kepdk_pol1  = vf1.Kep ./ (1.0 - modl.delta .+ vf1.i)
    Kepdk_pol0  = vf0.Kep ./ (1.0 - modl.delta .+ vf0.i)

    # Compute cost cutoffs
    xi_vec   = vf1.val - vf0.val
    # Probabilities of issuing
    Ei_prob  = reshape(xi_cdfun(xi_vec, modl, sspace, opts),Nd,1)

    # Order-1 (linear) spline bases on the histogram grids: weights of next-period states on grid nodes
    fspaceergm  = fundef([:spli, sspace.mgridf, 0, 1])
    fspaceergKe = fundef([:spli, sspace.Kegridf, 0, 1])
    Qm1         = funbase(fspaceergm, mp_pol1)
    Qm1msh      = funbase(fspaceergm, mp_pol1+ones(Nd)*modl.x_msh)
    QKe1        = funbase(fspaceergKe, Kepdk_pol1)
    Qm0         = funbase(fspaceergm, mp_pol0)
    Qm0msh      = funbase(fspaceergm, mp_pol0+ones(Nd)*modl.x_msh)
    QKe0        = funbase(fspaceergKe, Kepdk_pol0)

    # Construct product of QKe*Qm. The m' shock only moves m', so each policy set keeps its own
    # Ke transition in both the shock and no-shock branches (non-issuers use QKe0 in both).
    Q_1     = row_kron(QKe1,Qm1)
    Q_1msh  = row_kron(QKe1,Qm1msh)
    Q_0     = row_kron(QKe0,Qm0)
    Q_0msh  = row_kron(QKe0,Qm0msh)

    # Aggregate everything up, depending on shocks
    Q       = (1.0-modl.prob_msh) * (row_kron(reshape(1.0.-Ei_prob,Nd,1), Q_0) + row_kron(sparse(reshape(Ei_prob,Nd,1)), Q_1)) + modl.prob_msh * (row_kron(reshape(1.0.-Ei_prob,Nd,1), Q_0msh) + row_kron(sparse(reshape(Ei_prob,Nd,1)), Q_1msh))

    # Generating the mass of firms entering as incumbents -- to save on computations, could also do inside "setup!"
    Qme         = funbase(fspaceergm, ones(Nd)*modl.m_0)
    QKee        = funbase(fspaceergKe, ones(Nd)*0.0)
    # Multiply up into "full" transition matrix
    Qe          = row_kron(QKee, Qme)

    Le          = sparse(Qe' * ((1.0-modl.pi) * ones(Nd) / Nd))

    # Initialize "Lbar" coming in from previous period -- DON'T MULTIPLY BY "pi" HERE, BUT IN LAW OF MOTION
    L           = ones(size(Q,1))
    L           = sparse(L/sum(L))
    Lnew        = copy(L) # Initialize Lnew also to refer outside

    for itL = 1:opts.itermaxL
        Lnew    = modl.pi*Q'*L + Le         # "Decision-makers" L from t-1, iterated forward with "decisions" Q, only "pi" survive + entrants
        dL      = norm(Lnew-L)/norm(L)
        if dL < opts.tolL; manprintln("Convergence of m distribution in $itL iterations.", "mod_savedout/"*modl.name*"_output.txt"); break; end
        if mod(itL,20) == 0
            if opts.prnt == "Y"
                manprintln("iter: \t $itL, dLm: \t $dL", "mod_savedout/"*modl.name*"_output.txt")
            end
        end

        # Plot marginals along convergence (currently disabled by the trailing `& false`)
        if (mod(itL,4) == 0) & (itL<200) & (opts.plotSD == "Y") & false
            Lm  = kron(ones(1,sspace.nf[3]), I(sspace.nf[2])) * L
            LKe = kron(I(sspace.nf[3]), ones(1,sspace.nf[2])) * L
            hh=Plots.bar([sspace.mgridf,sspace.Kegridf],[Lm,LKe], layout=(2,1), label=[L"a^{b}" L"s"])
            Plots.savefig(hh,"mod_savedout/conv_figs/"*modl.name*"_solve_L_it$(itL).png")
        end
        L       = Lnew
    end

    # Create stationary distribution (marginals over m and over Ke) after convergence
    if opts.plotSD == "Y"
        Lm  = kron(ones(1,sspace.nf[3]), I(sspace.nf[2])) * Lnew
        LKe = kron(I(sspace.nf[3]), ones(1,sspace.nf[2])) * Lnew
        hh=Plots.bar([sspace.mgridf,sspace.Kegridf],[Lm,LKe], layout=(2,1), label=[L"a^{b}" L"s"])
        Plots.savefig(hh,"mod_savedout/conv_figs/"*modl.name*"_solve_L.png")
    end

    return Lnew
end


# Menu of model primitives, selected by `flag`.
#   "adj_cost": investment plus adjustment cost, i + nu*(i - delta_0)^eta, evaluated elementwise.
#               i_temp may be a scalar or an array of any length (including length 1).
# m_cur, y_temp, Kep_temp, sspace and opts are currently unused.
# Throws ArgumentError for an unrecognized flag (previously this surfaced as UndefVarError).
function menufun(flag, m_cur, i_temp, y_temp, Kep_temp, modl, sspace, opts)
    nu, eta, delta_0 = modl.nu, modl.eta, modl.delta_0

    # ADJUSTMENT COSTS INDEPENDENT OF PE VS GE
    # 2. Adjustment costs (broadcast, so scalars and arrays share one code path)
    if flag == "adj_cost"
        return @. i_temp + nu*(i_temp - delta_0)^eta
    end

    throw(ArgumentError("menufun: unknown flag \"$flag\""))
end

# Menu function for the probability of issuance given a vector of cost cutoffs x: the cdf of the
# issuance cost evaluated at each cutoff.
# Active specification: cost ~ Uniform[0, xi_ub], so P = min(max(x/xi_ub, 0), 1).
# Flip the `if true` to use a LogNormal(log(xi_ub), 1) cost instead.
# With xi_ub == 0 (no issuance costs) every firm issues, and a vector of ones is returned.
# The result is always a vector of length(x), including when length(x) == 1.
function xi_cdfun(x::AbstractVector{<:Real}, modl, sspace, opts)
    # Allow for no issuance costs
    if modl.xi_ub == 0.0
        return ones(length(x))
    end

    if true
        # Apply the uniform cdf, then clamp below at 0 and above at 1
        probs = min.(max.(0.0, x ./ modl.xi_ub), 1.0)
    else
        # Alternative: log Normal (ln(xi_ub), lnsig=1)
        mu_here     = log(modl.xi_ub)
        lnsig_here  = 1.0
        probs = cdf.(LogNormal(mu_here, lnsig_here), x)
    end

    return probs
end
# Scalar counterpart of the vector method: probability of issuing at a single cost cutoff x.
# Uniform[0, xi_ub] cost gives min(max(x/xi_ub, 0), 1); returns 1.0 when xi_ub == 0 (no issuance costs).
# Setting `use_uniform = false` switches to a LogNormal(log(xi_ub), 1) cost.
function xi_cdfun(x::Float64, modl, sspace, opts)
    ub = modl.xi_ub
    ub == 0.0 && return 1.0

    use_uniform = true
    if !use_uniform
        return cdf(LogNormal(log(ub), 1.0), x)
    end

    share = x / ub
    return min(max(0.0, share), 1.0)
end
# Menu function for the expected issuance cost conditional on the cost lying below each cutoff in x,
# i.e. E[xi | xi < x] elementwise.
# Uniform[0, xi_ub] cost: x/2, clamped to 0 from below and to the unconditional mean xi_ub/2 from above.
# Setting `use_uniform = false` switches to a LogNormal(log(xi_ub), 1) cost; NaN results
# (e.g. at x <= 0) are set to 0.
function xi_condexpn(x::Array{Float64,1}, modl, sspace, opts)
    use_uniform = true

    if use_uniform
        upper = 0.5*(0.0 + modl.xi_ub)
        return min.(max.(0.0, 0.5 .* x), upper)
    end

    mu, sig = log(modl.xi_ub), 1.0
    xpos = max.(x, zeros(size(x)))
    num  = cdf.(Normal(), (log.(xpos) - ones(size(xpos))*(mu + sig^2))/sig)
    den  = cdf.(Normal(), (log.(xpos) - ones(size(xpos))*mu)/sig)
    res  = exp(mu + (sig^2)/2) * num ./ den
    res[isnan.(res)] .= 0.0
    return res
end
# Scalar counterpart: expected issuance cost conditional on it lying below the cutoff x.
# Uniform[0, xi_ub] cost: x/2 clamped to [0, xi_ub/2]. Setting `use_uniform = false` switches to a
# LogNormal(log(xi_ub), 1) cost, with a NaN result (e.g. at x <= 0) replaced by 0.
function xi_condexpn(x::Float64, modl, sspace, opts)
    use_uniform = true

    if use_uniform
        half = 0.5 * (x + 0.0)
        return min(max(0.0, half), 0.5*(0.0 + modl.xi_ub))
    end

    mu, sig = log(modl.xi_ub), 1.0
    xpos = max(x, 0.0)
    res  = exp(mu + (sig^2)/2) * cdf(Normal(), (log(xpos) - (mu + sig^2))/sig) / cdf(Normal(), (log(xpos) - mu)/sig)
    return isnan(res) ? 0.0 : res
end


end # End module
